#!/usr/bin/python
# Copyright (c) 2014 The Native Client Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import optparse
import os
import re
import StringIO
import sys

from bionic_dirs import *

# License text plus an "autogenerated" banner. CreateIrtHeader and
# CreateIrtSource write this first into the generated header and source.
NOTICE ="""/*
 * Copyright (C) 2008 The Android Open Source Project
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
 * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

 /*
  * This file was autogenerated by irt_syscalls.py
  */
"""

# Opening of the generated header: includes, weak_alias/libc_hidden_def
# helper macros, forward declarations, and the start of _irt_syscalls_t.
# CreateIrtHeader appends one "nacl_irt_<name>" function-pointer field per
# IRT function directly after this text.
SYSCALLS_TOP = """#ifndef _IRT_SYSCALLS_H
#define _IRT_SYSCALLS_H

#include <sys/cdefs_elf.h>
#include <machine/cdefs.h>
#include <string.h>

#include <sys/types.h>
#include <sys/epoll.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <poll.h>
#include <stddef.h>
#include <fcntl.h>
#include <time.h>
#include <nacl_socket.h>
#include <nacl_stat.h>

#ifdef weak_alias
#undef weak_alias
#endif
#define weak_alias(oldname, newname)       \\
  extern typeof(oldname) newname  __attribute__ ((weak, alias (#oldname)));

#ifdef libc_hidden_def
#undef libc_hidden_def
#endif
#define libc_hidden_def(x)

struct dirent;
struct epoll_event;
struct msghdr;
struct nacl_abi_stat;
struct sockaddr;
struct timespec;
struct timeval;
struct NaClMemMappingInfo;

#include <irt.h>
#include <irt_dev.h>
#include <irt_poll.h>
#include <irt_socket.h>

typedef void (*nacl_start_func_t)(void);

__BEGIN_DECLS
typedef struct {
  size_t  (*nacl_irt_query)(const char* interface_ident, void* table, size_t tablesize);
"""

# Closes _irt_syscalls_t and declares the global table pointer used by the
# accessor macros. CreateIrtHeader then appends one
# "#define __nacl_irt_<name>" accessor per IRT function.
SYSCALLS_MID = """} _irt_syscalls_t;

extern _irt_syscalls_t* g_nacl_irt_syscalls_ptr;
#define _IRT_SYSCALL_PTR  g_nacl_irt_syscalls_ptr

#define __nacl_irt_query (_IRT_SYSCALL_PTR->nacl_irt_query)
"""

# Tail of the generated header: errno helpers and the closing include guard.
# NOTE(review): ENOSYS_IF_NULL references errno/ENOSYS, but this header does
# not include <errno.h>; users of the macro must include it themselves.
SYSCALLS_BOTTOM = """

int __nacl_abi_errno_to_errno(int err);
#define ENOSYS_IF_NULL(x) if ((x) == NULL) { errno = ENOSYS; return -1; }

__END_DECLS
#endif
"""


# Preamble of the generated irt_syscalls.c: includes, the per-libc syscall
# table and the pointer to it, and a stub that returns ENOSYS (38).
# CreateIrtSource writes one IRT_INIT_<group> function after this text.
IRT_INIT_START = """
#include <elf.h>
#include <irt_syscalls.h>
#include <stdlib.h>


// Each instance of LIBC will have its own global copy of the irt syscalls.
// When using the dynamic loader, the loader's irt pointers will need to
// get updated separately.

_irt_syscalls_t g_nacl_irt_syscalls_data;
_irt_syscalls_t* g_nacl_irt_syscalls_ptr = &g_nacl_irt_syscalls_data;

static int not_implemented() {
  return (38 /* ENOSYS */);
}


"""

# Opening of __init_irt_table. It clears the table, stores the query
# function, and is then followed by one "IRT_INIT_<group>(irt);" call per
# group, emitted by CreateIrtSource.
IRT_INIT_MID = """
void __init_irt_table(TYPE_nacl_irt_query query) {
  _irt_syscalls_t* irt = g_nacl_irt_syscalls_ptr;

  if ((irt == NULL) || (query == NULL))
    return;

  memset(irt, 0, sizeof(_irt_syscalls_t));
  irt->nacl_irt_query = query;

"""

# Closes __init_irt_table and defines __init_irt_from_auxv, which finds the
# IRT query function in the auxiliary vector and initializes the table.
# The generated C comments describe the actual control flow:
# __init_irt_table returns early (it does not crash) when query is NULL.
IRT_INIT_END = """}

#ifndef AT_NULL
#define AT_NULL 0
#endif

#ifndef AT_SYSINFO
#define AT_SYSINFO 32
#endif

void __init_irt_from_auxv (uintptr_t *auxv) {
  for (; *auxv != AT_NULL; auxv += 2) {
    // The AT_SYSINFO entry carries the IRT query function; store it in the
    // IRT syscall table.
    if (*auxv == AT_SYSINFO) {
      __nacl_irt_query = (size_t (*)(const char *, void*, size_t)) auxv[1];
    }
  }
  // If no AT_SYSINFO entry was found, __nacl_irt_query keeps its previous
  // value (NULL for the zero-initialized global table) and __init_irt_table
  // returns without populating the table, so later IRT calls will go
  // through NULL function pointers. This should not happen.
  __init_irt_table(__nacl_irt_query);
}
"""


def Error(str):
  """Report a fatal error and terminate the script.

  Args:
    str: Message to print on stdout before exiting with status 1.
  """
  message = str
  print(message)
  sys.exit(1)

def _SplitTopLevelCommas(text):
  """Split text on commas that are not nested inside () or [].

  Each piece is stripped. An empty text yields [''], which is the same shape
  as text.split(',').
  """
  pieces = []
  depth = 0
  start = 0
  for index, char in enumerate(text):
    if char in '([':
      depth += 1
    elif char in ')]':
      depth -= 1
    elif char == ',' and depth == 0:
      pieces.append(text[start:index].strip())
      start = index + 1
  pieces.append(text[start:].strip())
  return pieces


def SplitFunction(line):
  """Split one function-pointer member declaration into name and signature.

  Expects a stripped line of the form "RET (*NAME)(ARGS);", for example
  "int (*close)(int fd);".

  Returns:
    A two-element list [name, [ret, arg1, arg2, ...]]. An empty parameter
    list yields a single empty-string argument.
  """
  left = line.find('(')
  right = line.find(')')

  # Skip the "(*" that introduces the member name.
  name = line[left + 2:right].strip()
  ret = line[:left].strip()
  # Skip the ")(" after the name and the trailing ");". Split only on
  # top-level commas so that a parameter which is itself a function pointer,
  # e.g. "void (*cb)(int, void* ctx)", stays in one piece.
  args = _SplitTopLevelCommas(line[right + 2:-2])
  return [name, [ret] + args]


def GetFunctions(fileobj):
  """Read the function-pointer members of one IRT struct body.

  Lines are consumed from fileobj with readline() up to and including the
  closing "};" line, or until EOF. Blank lines are skipped. Every other line
  must be a function-pointer member, which SplitFunction parses.

  Returns:
    A list of [name, [ret, args...]] entries in declaration order.
  """
  functions = []
  while True:
    raw = fileobj.readline()
    if not raw:
      break

    line = raw.strip()
    if not line:
      continue

    if line.startswith('};'):
      break

    # SplitFunction slices by position and assumes this exact shape. Fail
    # loudly on anything else rather than emitting a garbage entry.
    if '(*' not in line or not line.endswith(');'):
      Error('LINE: >>%s<<\n\t is not a function pointer member '
            '(expected "RET (*NAME)(ARGS);")' % raw)
    functions.append(SplitFunction(line))
  return functions


def ReadFile(filename):
  """Read a C header and normalize it for line-oriented parsing.

  The returned text has:
    * /* ... */ and // comments replaced by a single space (string and
      character literals are kept verbatim, so comment markers inside them
      are not treated as comments);
    * backslash-newline continuations removed;
    * every " *" rewritten as "* " (so "char *p" becomes "char* p");
    * any declaration whose parentheses span several lines joined onto a
      single line.

  Returns:
    The normalized text.
  """
  # Group 1 matches string/char literals; group 2 matches comments. Matching
  # literals first keeps "//" or "/*" inside a literal from being stripped.
  comment_regex = re.compile(r"(\".*?\"|\'.*?\')|(/\*.*?\*/|//[^\r\n]*$)",
                             re.MULTILINE | re.DOTALL)

  def _StripComment(match):
    # Keep literals. Replace a comment with one space so that the tokens on
    # either side of it stay separated.
    if match.group(1) is not None:
      return match.group(1)
    return ' '

  with open(filename, 'r') as f:
    text = f.read()

  text = comment_regex.sub(_StripComment, text)

  # Unwrap backslash-continued lines.
  text = text.replace('\\\n', '')

  # Join declarations broken across lines. Paren counts are cumulative from
  # the start of the file; while they differ, we are inside an unfinished
  # parenthesized list, so the newline is dropped.
  out = []
  opened = 0
  closed = 0
  for line in text.split('\n'):
    line = line.replace(' *', '* ')

    if opened != closed:
      line = ' ' + line.lstrip()

    opened += line.count('(')
    closed += line.count(')')
    if opened != closed:
      out.append(line.rstrip())
    else:
      out.append(line + '\n')
  return ''.join(out)


def ReplaceText(src, dst):
  """Copy the IRT header src to dst, rewriting it for use in this tree.

  Every occurrence of "off_t" becomes "int64_t", and the include-path prefix
  "native_client/src/untrusted/irt/" is removed.

  NOTE: these are plain substring replacements, so identifiers that merely
  contain "off_t" (for example "nacl_abi_off_t") are rewritten too.
  """
  print('Munging IRT %s -> %s.' % (src, dst))
  # The two patterns cannot overlap, so the order of application does not
  # affect the result.
  replacements = [
    ('off_t', 'int64_t'),
    ('native_client/src/untrusted/irt/', ''),
  ]
  with open(src, 'r') as srcf:
    text = srcf.read()
  for old, new in replacements:
    text = text.replace(old, new)
  with open(dst, 'w') as dstf:
    dstf.write(text)


def GetGroupName(name):
  """Derive the interface group name from an IRT interface define name.

  The last two '_'-separated tokens (the version, e.g. "v0_1") are dropped.
  A leading "DEV" token is dropped as well, so "DEV_FDIO_v0_3" and
  "FDIO_v0_1" both map to "FDIO".

  NOTE(review): ScanFile passes names that begin with "NACL_IRT_", so the
  leading-"DEV" rule does not fire for them. For example,
  "NACL_IRT_DEV_FDIO_v0_3" maps to "NACL_IRT_DEV_FDIO", separate from
  "NACL_IRT_FDIO". Confirm that this grouping is intended.
  """
  tokens = name.split('_')
  if tokens[0] == 'DEV':
    tokens = tokens[1:]
  return '_'.join(tokens[:-2])


def ScanFile(filename, group_map):
  """Collect the IRT interface structs defined in a header.

  Each "struct nacl_irt_... {" definition is paired with the most recent
  "#define NACL_IRT_..." line before it. The entry
  [define_name, struct_name, functions] is appended to
  group_map[GetGroupName(define_name)]. group_map is updated in place.

  Exits via Error() if a struct definition appears before any such define.
  """
  struct_str = 'struct nacl_irt_'
  define_str = '#define NACL_IRT_'

  text = ReadFile(filename)
  f = StringIO.StringIO(text)

  define_name = None
  group_name = None
  # Use readline() rather than plain iteration: GetFunctions consumes each
  # struct body from the same stream.
  for raw in iter(f.readline, ''):
    line = raw.strip()
    if line.startswith(define_str):
      define_name = 'NACL_IRT_' + line[len(define_str):].split()[0]
      group_name = GetGroupName(define_name)
      continue
    if line.startswith(struct_str) and line.endswith('{'):
      if define_name is None:
        Error('%s: struct >>%s<< appears before any "%s..." line' %
              (filename, line, define_str))
      struct_name = line.split()[1]
      struct_data = GetFunctions(f)
      group_map.setdefault(group_name, []).append(
          [define_name, struct_name, struct_data])


def GetFunctionMap(group_map):
  """Merge the functions of every group into one name -> signature map.

  Only the last entry in each group's list (the most recently scanned
  struct) contributes. A function that appears in several groups must have
  the same signature everywhere; otherwise the script exits via Error().

  Returns:
    A dict mapping each function name to its [ret, arg1, ...] spec.
  """
  merged = {}
  for entries in group_map.values():
    _, _, latest_funcs = entries[-1]
    for function in latest_funcs:
      name, spec = function[0], function[1]
      if name not in merged:
        merged[name] = spec
      elif merged[name] != spec:
        Error('Function %s with %s does not match previous spec %s.' % (
              name, spec, merged[name]))
  return merged


def CreateIrtHeader(filename, group_map):
  """Write the generated IRT syscall header to filename.

  After NOTICE and SYSCALLS_TOP, one "nacl_irt_<name>" function-pointer
  field is written per function (sorted by name), then SYSCALLS_MID, then a
  "__nacl_irt_<name>" accessor macro per function, then SYSCALLS_BOTTOM.
  """
  signatures = GetFunctionMap(group_map)
  ordered = sorted(signatures)

  with open(filename, 'w') as out:
    out.write(NOTICE)
    out.write(SYSCALLS_TOP)
    for fn in ordered:
      ret_type = signatures[fn][0]
      params = ', '.join(signatures[fn][1:])
      out.write('  %s (*nacl_irt_%s)(%s);\n' % (ret_type, fn, params))
    out.write(SYSCALLS_MID)
    for fn in ordered:
      out.write('#define __nacl_irt_%s (_IRT_SYSCALL_PTR->nacl_irt_%s)\n'
                % (fn, fn))
    out.write(SYSCALLS_BOTTOM)

def WriteGroupInit(f, group_name, group_data):
  """Emit the C function IRT_INIT_<group_name> for one interface group.

  group_data holds [define, struct_name, functions] entries in scan order.
  The generated function queries the versions newest-first (the reverse of
  group_data) into a buffer typed as the newest struct. It stops at the
  first query whose returned size equals that version's struct size, then
  copies every function of the newest struct into the syscall table.
  """
  newest_first = group_data[::-1]
  defines = [entry[0] for entry in newest_first]
  structs = [entry[1] for entry in newest_first]
  func_sets = [entry[2] for entry in newest_first]

  f.write("static void IRT_INIT_%s(_irt_syscalls_t* irt) {\n" % group_name)
  f.write("  TYPE_nacl_irt_query query = irt->nacl_irt_query;\n")
  f.write("  struct %s funcs;\n" % structs[0])
  f.write("  memset(&funcs, 0, sizeof(funcs));\n")
  f.write("  do {\n")
  for define, struct_name in zip(defines, structs):
    f.write("    if (query(%s, &funcs, sizeof(struct %s))\n"
            "        == sizeof(struct %s)) break;\n" %
            (define, struct_name, struct_name))
  f.write("  } while (0);\n")

  newest_funcs = func_sets[0]
  for func in newest_funcs:
    f.write("  irt->nacl_irt_%s = funcs.%s;\n" % (func[0], func[0]))
  f.write("}\n\n")


def CreateIrtSource(filename, group_map):
  """Write the generated IRT syscall source file to filename.

  One IRT_INIT_<group> function is emitted per group, sorted by group name.
  The generated __init_irt_table then calls them all in that same order.
  """
  with open(filename, 'w') as out:
    out.write(NOTICE)
    out.write(IRT_INIT_START)
    ordered_groups = sorted(group_map)
    for name in ordered_groups:
      WriteGroupInit(out, name, group_map[name])
    out.write(IRT_INIT_MID)
    out.write(''.join('  IRT_INIT_%s(irt);\n' % name
                      for name in ordered_groups))
    out.write(IRT_INIT_END)


def DumpGroupMap(groups):
  """Print a human-readable listing of a group map as built by ScanFile.

  For each group, prints the group name, then each struct with its define,
  then the names of that struct's functions.
  """
  for group_name, entries in groups.items():
    print('GROUP: ' + group_name)
    for entry in entries:
      define, struct, funcs = entry
      print('\t%s : %s' % (struct, define))
      for func in funcs:
        print('\t\t' + func[0])


def main(argv):
  """Scan IRT headers and generate the IRT syscall header and source.

  Args:
    argv: Full command line, including the program name at argv[0].
      Positional arguments name the headers to scan. The default is
      irt.h and irt_dev.h, resolved relative to the current directory.

  Returns:
    0 on success. Fatal problems exit through Error().
  """
  parser = optparse.OptionParser()
  parser.add_option(
      '-v', '--verbose', dest='verbose',
      default=False, action='store_true',
      help='Produce more output.')
  parser.add_option(
      '-i', '--include', dest='include',
      default=os.path.join(BIONIC_SRC, 'irt', 'irt_syscalls.h'),
      help='Output include filename.')
  parser.add_option(
      '-s', '--source', dest='source',
      default=os.path.join(BIONIC_SRC, 'irt', 'irt_syscalls.c'),
      help='Output source filename.')

  options, args = parser.parse_args(argv[1:])
  if not args:
    args = [
      'irt.h',
      'irt_dev.h',
    ]

  groups = {}
  for filename in args:
    print('Scanning IRT: ' + filename)
    ScanFile(filename, groups)

  # --verbose: list every group, struct and function found before writing
  # the generated files.
  if options.verbose:
    DumpGroupMap(groups)

  CreateIrtHeader(options.include, groups)
  CreateIrtSource(options.source, groups)
  return 0

# Script entry point: run the generator and use main()'s return value as the
# process exit status.
if __name__ == '__main__':
  sys.exit(main(sys.argv))
